#!/usr/bin/python
import os
def generateShape(rootDir):
    """Regenerate ``source/shape/ShapeRegister.cpp`` below ``rootDir``.

    Every regular file directly inside ``source/shape`` and, when that
    directory exists, ``source/shape/render`` is scanned line by line for the
    shape registration macros listed in ``macroTargets``.  Files whose *name*
    contains ``.hpp`` are skipped.  For each macro call the first two
    arguments form the symbol ``___<arg0>__<arg1>__``.

    The generated file declares every symbol as ``extern void`` inside
    ``namespace MNN`` and calls all of them from ``registerShapeOps()``.
    Symbols from ``REGISTER_SHAPE_INPUTS_RENDER`` are wrapped in
    ``#ifdef MNN_SUPPORT_RENDER``; symbols from
    ``REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE`` are wrapped in
    ``#ifdef MNN_SUPPORT_TRANSFORMER_FUSE``.

    Args:
        rootDir: Directory that contains ``source/shape``.

    Raises:
        OSError or IOError: If ``source/shape`` cannot be listed or a file
            cannot be read or written.
        IndexError: If a matched macro call has fewer than two
            comma-separated arguments before its first ``)``.
    """
    shapeDir = os.path.join(rootDir, "source", "shape")
    shapeRenderDir = os.path.join(shapeDir, "render")
    shapeRegFile = os.path.join(shapeDir, "ShapeRegister.cpp")

    shapeLists = []
    renderShape = []
    transformerFuseShape = []
    # (macro name, list that receives its symbols).  A macro only matches when
    # it is immediately followed by "(", so the names cannot shadow each other.
    macroTargets = (
        ('REGISTER_SHAPE', shapeLists),
        ('REGISTER_SHAPE_OLD', shapeLists),
        ('REGISTER_SHAPE_INPUTS', shapeLists),
        ('REGISTER_SHAPE_INPUTS_RENDER', renderShape),
        ('REGISTER_SHAPE_INPUTS_TRANSFORMER_FUSE', transformerFuseShape),
    )

    def collectFile(path):
        """Append the symbols registered in the file at ``path``."""
        if os.path.isdir(path):
            return
        # Test the file name only: a ".hpp" somewhere in rootDir must not
        # cause every source file to be skipped.
        if ".hpp" in os.path.basename(path):
            return
        with open(path) as source:
            lines = source.read().split('\n')
        for line in lines:
            for macro, target in macroTargets:
                marker = macro + '('
                if marker not in line:
                    continue
                # Keep only the text between the marker and the first ")",
                # so indentation (tabs included) never leaks into the symbol.
                args = line.split(marker, 1)[1].split(')')[0]
                args = args.replace(' ', '').split(',')
                target.append('___' + args[0] + '__' + args[1] + '__')
                break

    print(shapeRegFile)
    for fi in os.listdir(shapeDir):
        collectFile(os.path.join(shapeDir, fi))
    if os.path.isdir(shapeRenderDir):
        for fi in os.listdir(shapeRenderDir):
            collectFile(os.path.join(shapeRenderDir, fi))

    def externs(names):
        return ["extern void " + name + '();\n' for name in names]

    def calls(names):
        return [name + '();\n' for name in names]

    out = ['// This file is generated by Shell for ops register\n',
           'namespace MNN {\n']
    out += externs(shapeLists)
    out.append('\n')
    out.append('#ifdef MNN_SUPPORT_RENDER\n')
    out += externs(renderShape)
    out.append('#endif\n')
    out.append('#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n')
    out += externs(transformerFuseShape)
    out.append('#endif\n')
    out.append('void registerShapeOps() {\n')
    out += calls(shapeLists)
    out.append('#ifdef MNN_SUPPORT_RENDER\n')
    out += calls(renderShape)
    out.append('#endif\n')
    out.append('#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n')
    out += calls(transformerFuseShape)
    out.append('#endif\n')
    out.append("}\n}\n")

    # The file is written exactly once, after all sources have been scanned.
    with open(shapeRegFile, 'w') as f:
        f.writelines(out)
    return

def generateCPUFile(rootDir):
    """Regenerate ``source/backend/cpu/CPUOPRegister.cpp`` below ``rootDir``.

    Regular files directly inside ``source/backend/cpu`` and, when it exists,
    ``source/backend/cpu/render`` are scanned for lines that contain both
    ``REGISTER_CPU_OP_CREATOR`` and ``OpType``.  The first two arguments
    between the first ``(`` and the following ``)`` of such a line form the
    symbol ``___<arg0>__<arg1>__``.  Lines using the ``_RENDER`` or
    ``_TRANSFORMER`` macro variants are collected separately and emitted
    inside ``#ifdef MNN_SUPPORT_RENDER`` / ``#ifdef
    MNN_SUPPORT_TRANSFORMER_FUSE``.  The generated file declares every symbol
    and calls all of them from ``registerCPUOps()``.

    Args:
        rootDir: Directory that contains ``source/backend/cpu``.
    """
    cpuDir = os.path.join(rootDir, "source", "backend", "cpu")
    cpuRenderDir = os.path.join(cpuDir, "render")
    cpuRegFile = os.path.join(cpuDir, "CPUOPRegister.cpp")

    plainOps = []
    renderOps = []
    transformerOps = []

    def scan(entries, folder):
        """Collect the symbols registered by the files ``entries`` of ``folder``."""
        for entry in entries:
            path = os.path.join(folder, entry)
            if os.path.isdir(path):
                continue
            with open(path) as src:
                rows = src.read().split('\n')
            for row in rows:
                if 'REGISTER_CPU_OP_CREATOR' not in row or 'OpType' not in row:
                    continue
                args = row.split('(')[1].split(')')[0].replace(' ', '').split(',')
                symbol = '___' + args[0] + '__' + args[1] + '__'
                # The macro variant decides which guard the symbol ends up in.
                if 'REGISTER_CPU_OP_CREATOR_RENDER' in row:
                    renderOps.append(symbol)
                elif 'REGISTER_CPU_OP_CREATOR_TRANSFORMER' in row:
                    transformerOps.append(symbol)
                else:
                    plainOps.append(symbol)

    topLevel = os.listdir(cpuDir)
    print(topLevel)
    scan(topLevel, cpuDir)
    if os.path.isdir(cpuRenderDir):
        scan(os.listdir(cpuRenderDir), cpuRenderDir)

    def declare(symbols):
        return ["extern void " + s + '();\n' for s in symbols]

    def invoke(symbols):
        return [s + '();\n' for s in symbols]

    def ifdef(flag, body):
        return ['#ifdef ' + flag + '\n'] + body + ['#endif\n']

    text = ['// This file is generated by Shell for ops register\n',
            'namespace MNN {\n']
    text += declare(plainOps)
    text.append('\n')
    text += ifdef('MNN_SUPPORT_RENDER', declare(renderOps))
    text += ifdef('MNN_SUPPORT_TRANSFORMER_FUSE', declare(transformerOps))
    text.append('void registerCPUOps() {\n')
    text += invoke(plainOps)
    text += ifdef('MNN_SUPPORT_RENDER', invoke(renderOps))
    text += ifdef('MNN_SUPPORT_TRANSFORMER_FUSE', invoke(transformerOps))
    text.append("}\n}\n")

    with open(cpuRegFile, 'w') as f:
        f.write(''.join(text))

def generateOPENCLFile(rootDir):
    """Regenerate ``source/backend/opencl/core/OpenCLOPRegister.cpp``.

    Regular files directly inside ``execution/buffer`` and ``execution/image``
    (both below ``source/backend/opencl``) are scanned for lines that contain
    both ``REGISTER_OPENCL_OP_CREATOR`` and ``OpType``.  The first two
    arguments between the first ``(`` and the following ``)`` form the symbol
    ``___OpenCL<arg0>__<arg1>__BUFFER__`` or ``...__IMAGE__`` depending on
    the directory the file was found in.  Lines using
    ``REGISTER_OPENCL_OP_CREATOR_TRANSFORMER`` are collected separately and
    emitted inside ``#ifdef MNN_SUPPORT_TRANSFORMER_FUSE``; buffer symbols are
    emitted inside ``#ifndef MNN_OPENCL_BUFFER_CLOSED``.  The whole generated
    file is wrapped in ``#ifndef MNN_OPENCL_SEP_BUILD``.

    Args:
        rootDir: Directory that contains ``source/backend/opencl``.

    Raises:
        OSError or IOError: If either execution directory cannot be listed or
            a file cannot be read or written.
    """
    openclDir = os.path.join(rootDir, "source", "backend", "opencl")
    openclBufferDir = os.path.join(openclDir, "execution", "buffer")
    openclImageDir = os.path.join(openclDir, "execution", "image")
    openclRegFile = os.path.join(openclDir, "core", "OpenCLOPRegister.cpp")
    opNamesImage = []
    opNamesBuffer = []
    transformerNames = []

    def collectFile(fileNames, dirname, end, opNames):
        """Scan ``fileNames`` of ``dirname``; ``end`` is the symbol suffix.

        The suffix and target list are passed in explicitly instead of being
        guessed from the path text: a root path that merely contains the
        word "buffer" must not turn image ops into buffer ops.
        """
        for fi in fileNames:
            f = os.path.join(dirname, fi)
            if os.path.isdir(f):
                continue
            with open(f) as fileC:
                lines = fileC.read().split('\n')
            for lo in lines:
                if 'REGISTER_OPENCL_OP_CREATOR' not in lo or 'OpType' not in lo:
                    continue
                args = lo.split('(')[1].split(')')[0].replace(' ', '').split(',')
                funcName = '___' + 'OpenCL' + args[0] + '__' + args[1] + end
                if 'REGISTER_OPENCL_OP_CREATOR_TRANSFORMER' in lo:
                    transformerNames.append(funcName)
                else:
                    opNames.append(funcName)

    bufferFileNames = os.listdir(openclBufferDir)
    print(bufferFileNames)
    collectFile(bufferFileNames, openclBufferDir, '__BUFFER__', opNamesBuffer)

    imageFileNames = os.listdir(openclImageDir)
    print(imageFileNames)
    collectFile(imageFileNames, openclImageDir, '__IMAGE__', opNamesImage)

    with open(openclRegFile, 'w') as f:
        f.write('// This file is generated by Shell for ops register\n')
        f.write('#ifndef MNN_OPENCL_SEP_BUILD\n')
        f.write('namespace MNN {\n')
        f.write('namespace OpenCL {\n')
        f.write('#ifndef ' + 'MNN_OPENCL_BUFFER_CLOSED' + '\n')
        for l in opNamesBuffer:
            f.write("extern void " + l + '();\n')
        f.write('#endif\n')
        for l in opNamesImage:
            f.write("extern void " + l + '();\n')
        f.write('\n')
        f.write('#ifdef ' + 'MNN_SUPPORT_TRANSFORMER_FUSE' + '\n')
        for l in transformerNames:
            f.write("extern void " + l + '();\n')
        f.write('#endif\n')
        f.write('void registerOpenCLOps() {\n')
        f.write('#ifndef ' + 'MNN_OPENCL_BUFFER_CLOSED' + '\n')
        for l in opNamesBuffer:
            f.write(l + '();\n')
        f.write('#endif\n')
        for l in opNamesImage:
            f.write(l + '();\n')
        f.write('\n')
        f.write('#ifdef ' + 'MNN_SUPPORT_TRANSFORMER_FUSE' + '\n')
        for l in transformerNames:
            f.write(l + '();\n')
        f.write('#endif\n')
        f.write("}\n}\n}\n")
        f.write('#endif\n')

def generateGeoFile(rootDir):
    """Regenerate ``source/geometry/GeometryOPRegister.cpp`` below ``rootDir``.

    Regular files directly inside ``source/geometry`` whose name does not
    contain ``.hpp`` are scanned for lines containing ``REGISTER_GEOMETRY``.
    The first two arguments between the first ``(`` and the following ``)``
    of such a line form the symbol ``___<arg0>__<arg1>__``.  The generated
    file declares every symbol and calls all of them from
    ``registerGeometryOps()``.

    Nothing is written when the directory has at most one entry.

    Args:
        rootDir: Directory that contains ``source/geometry``.
    """
    geoDir = os.path.join(rootDir, "source", "geometry")
    regFile = os.path.join(geoDir, "GeometryOPRegister.cpp")
    entries = os.listdir(geoDir)
    print(entries)
    if len(entries) < 2:
        # Error dirs
        return

    symbols = []
    for entry in entries:
        if ".hpp" in entry:
            continue
        path = os.path.join(geoDir, entry)
        if os.path.isdir(path):
            continue
        with open(path) as src:
            rows = src.read().split('\n')
        for row in rows:
            if 'REGISTER_GEOMETRY' not in row:
                continue
            args = row.split('(')[1].split(')')[0].replace(' ', '').split(',')
            symbols.append('___' + args[0] + '__' + args[1] + '__')

    pieces = [
        '// This file is generated by Shell for ops register\n',
        '#include \"geometry/GeometryComputer.hpp\"\n',
        'namespace MNN {\n',
    ]
    pieces.extend("extern void " + s + '();\n' for s in symbols)
    pieces.append('\n')
    pieces.append('void registerGeometryOps() {\n')
    pieces.extend(s + '();\n' for s in symbols)
    pieces.append("}\n}\n")

    with open(regFile, 'w') as f:
        f.write(''.join(pieces))

def generateCoreMLFile(rootDir):
    """Regenerate ``source/backend/coreml/backend/CoreMLOPRegister.cpp``.

    Regular files directly inside ``source/backend/coreml/execution`` are
    scanned for lines that contain both ``REGISTER_COREML_OP_CREATOR`` and
    ``OpType``.  The first two arguments between the first ``(`` and the
    following ``)`` of such a line form the symbol ``___<arg0>__<arg1>__``.
    The generated file declares every symbol and calls all of them from
    ``registerCoreMLOps()``.

    Nothing is written when the execution directory has at most one entry.

    Args:
        rootDir: Directory that contains ``source/backend/coreml``.
    """
    backendRoot = os.path.join(rootDir, "source", "backend", "coreml")
    executionDir = os.path.join(backendRoot, "execution")
    registerPath = os.path.join(backendRoot, "backend", "CoreMLOPRegister.cpp")
    entries = os.listdir(executionDir)
    print(entries)
    if len(entries) < 2:
        # Error dirs
        return

    def symbolsIn(path):
        """Return the symbols registered by the file at ``path``, in line order."""
        with open(path) as src:
            rows = src.read().split('\n')
        found = []
        for row in rows:
            if 'REGISTER_COREML_OP_CREATOR' in row and 'OpType' in row:
                args = row.split('(')[1].split(')')[0].replace(' ', '').split(',')
                found.append('___' + args[0] + '__' + args[1] + '__')
        return found

    symbols = []
    for entry in entries:
        path = os.path.join(executionDir, entry)
        if not os.path.isdir(path):
            symbols.extend(symbolsIn(path))

    declarations = ''.join("extern void " + s + '();\n' for s in symbols)
    invocations = ''.join(s + '();\n' for s in symbols)
    with open(registerPath, 'w') as f:
        f.write('// This file is generated by Shell for ops register\n'
                + 'namespace MNN {\n'
                + declarations
                + '\n'
                + 'void registerCoreMLOps() {\n'
                + invocations
                + "}\n}\n")

def generateNNAPIFile(rootDir):
    """Regenerate ``source/backend/nnapi/backend/NNAPIOPRegister.cpp``.

    Regular files directly inside ``source/backend/nnapi/execution`` are
    scanned for lines that contain both ``REGISTER_NNAPI_OP_CREATOR`` and
    ``OpType``.  The first two arguments between the first ``(`` and the
    following ``)`` of such a line form the symbol ``___<arg0>__<arg1>__``.
    The generated file declares every symbol and calls all of them from
    ``registerNNAPIOps()``.

    Nothing is written when the execution directory has at most one entry.

    Args:
        rootDir: Directory that contains ``source/backend/nnapi``.
    """
    nnapiDir = os.path.join(rootDir, "source", "backend", "nnapi")
    nnapiExeDir = os.path.join(nnapiDir, "execution")
    nnapiRegFile = os.path.join(nnapiDir, "backend", "NNAPIOPRegister.cpp")
    listing = os.listdir(nnapiExeDir)
    print(listing)
    if not len(listing) > 1:
        # Error dirs
        return

    registered = []
    for name in listing:
        candidate = os.path.join(nnapiExeDir, name)
        if os.path.isdir(candidate):
            continue
        with open(candidate) as handle:
            content = handle.read()
        matches = [
            text for text in content.split('\n')
            if text.find('REGISTER_NNAPI_OP_CREATOR') >= 0 and text.find('OpType') >= 0
        ]
        for text in matches:
            inner = text.split('(')[1]
            inner = inner.split(')')[0]
            fields = inner.replace(' ', '').split(',')
            registered.append('___' + fields[0] + '__' + fields[1] + '__')

    output = ['// This file is generated by Shell for ops register\n']
    output.append('namespace MNN {\n')
    for symbol in registered:
        output.append("extern void " + symbol + '();\n')
    output.append('\n')
    output.append('void registerNNAPIOps() {\n')
    for symbol in registered:
        output.append(symbol + '();\n')
    output.append("}\n}\n")

    with open(nnapiRegFile, 'w') as f:
        f.writelines(output)

import sys

# Script driver: the first command-line argument is the root directory that is
# handed to every generator.  Fail with a readable usage line instead of an
# IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit("usage: " + (sys.argv[0] if sys.argv else "register.py") + " <rootDir>")

# The generators run in this fixed order, each rewriting its own register file.
generateShape(sys.argv[1])
generateCPUFile(sys.argv[1])
generateGeoFile(sys.argv[1])
generateCoreMLFile(sys.argv[1])
generateNNAPIFile(sys.argv[1])
generateOPENCLFile(sys.argv[1])
